#!/usr/bin/env python3

"""This script evaluates 'new state detection' models.

1. It rolls out each expert checkpoint (saved every N steps) to collect transitions, and saves them to disk.
2. It exposes the detector to the first half of the weakest expert's episodes, scores the held-out
   half of every expert's episodes, and logs the results to wandb.

The new-state detection model is expected to:
1. Be exposed to a set of states (e.g., s_t ~ expert-100).
2. Be evaluated independently on two sets of states (e.g., (a) s' ~ expert-10; (b) s'' ~ expert-100).
3. Assign a larger 'new'-ness score to s' than to s''.
"""
from __future__ import annotations
import functools
from copy import deepcopy
from pathlib import Path
from typing import List

import os
import numpy as np
import torch
import wandb
from rpi import logger
from rpi.helpers import set_random_seed, to_torch
from rpi.helpers.data import flatten
from rpi.helpers.env import rollout_single_ep
from rpi.nn.empirical_normalization import EmpiricalNormalization
from rpi.policies import (
    GaussianHeadWithStateIndependentCovariance,
    SoftmaxCategoricalHead,
)
from rpi.scripts.sweep.default_args import Args
from rpi.helpers.factory import Factory
from rpi.scripts.train import get_expert
from . import NewStateDetector, extract_states, simple_make_env
from .distance_based_methods import DiscreteSigma, EuclideanDistance, WassersteinDistance


def evaluate(state_detector: NewStateDetector, source_states, target_states):
    """Return the average score ``state_detector`` assigns to ``target_states``.

    The detector is first shown ``source_states`` via ``experience`` (which may
    mutate the detector), then every target state is scored in one
    ``batch_evaluate`` call and the per-state scores are averaged with
    ``np.mean``.
    """
    # Step 1: let the detector observe the reference ("already seen") states.
    state_detector.experience(source_states)

    # Step 2: score all target states at once and collapse them to a scalar.
    per_state_scores = state_detector.batch_evaluate(target_states)
    mean_score = np.mean(per_state_scores)
    return mean_score


def main(
    env_name: str,
    load_steps: List[int],
    save_dir,
    state_detector: NewStateDetector,
    seed: int,
    num_episodes: int = 100,
    max_episode_len: int = 1000,
    save_video: bool = False,
    deterministic_expert: bool = False,
):
    """Roll out (or load cached) expert episodes and score them with a new-state detector.

    For every checkpoint step in ``load_steps`` an expert is loaded and
    ``2 * num_episodes`` episodes are rolled out, or loaded from the cache under
    ``save_dir`` if present. The detector is exposed to the first
    ``num_episodes`` episodes of the expert with the smallest step (treated as
    the weakest expert). It is then evaluated on the held-out second half of
    every expert's episodes. Return and score statistics are logged to wandb.

    Args:
        env_name: Environment name passed to ``simple_make_env``.
        load_steps: Expert checkpoint steps to evaluate. Must be non-empty.
        save_dir: Root directory of the episode cache (``str`` or ``Path``).
        state_detector: Detector exposing ``experience`` and ``batch_evaluate``.
        seed: Seed forwarded to ``set_random_seed`` and ``simple_make_env``.
        num_episodes: Size of each split; ``2 * num_episodes`` episodes are
            rolled out per expert.
        max_episode_len: Maximum rollout length passed to ``rollout_single_ep``.
        save_video: Currently unused; kept for interface compatibility.
        deterministic_expert: Forwarded as ``mode`` to ``expert.act``; also part
            of the cache path.

    Raises:
        ValueError: If ``load_steps`` is empty.
    """
    if not load_steps:
        # Fail fast with a clear message. Otherwise this would only surface as an
        # IndexError after the (expensive) expert loading and rollouts below.
        raise ValueError("load_steps must contain at least one expert step")

    # Accept both `str` and `Path`; cache paths below are built with `/`.
    save_dir = Path(save_dir)

    set_random_seed(seed)

    # Prepare make_env function
    make_env, state_dim, act_dim, env_id = simple_make_env(env_name, default_seed=seed)

    # Load pretrained experts; each one gets its own copy of the policy head.
    policy_head = GaussianHeadWithStateIndependentCovariance(
        action_size=act_dim,
        var_type="diagonal",
        var_func=lambda x: torch.exp(2 * x),  # Parameterize log std
        var_param_init=0,  # log std = 0 => std = 1
    )
    step2expert = {
        load_step: get_expert(
            state_dim,
            act_dim,
            deepcopy(policy_head),
            Path(Args.experts_dir) / env_id / f"step_{load_step:06d}.pt",
            obs_normalizer=None,
        )
        for load_step in load_steps
    }

    # Map each expert step to its rollout episodes.
    # Load them from the cache if present; otherwise roll them out and cache them.
    # NOTE(review): the cache path does not encode `num_episodes` or
    # `max_episode_len`, so a cache produced with different values is reused as-is.
    cache_dir = save_dir / env_id / f"deterministic-{deterministic_expert}"
    step2episodes = {}
    env = make_env()
    for step, expert in step2expert.items():
        fpath = cache_dir / f"expert-{step:06d}.pt"
        if fpath.exists():
            logger.info(f"Loading episodes from {fpath}...")
            # NOTE: torch.load unpickles the file; only load caches you produced.
            obj = torch.load(fpath)
            episodes = obj["episodes"]
            ep_returns = obj["ep_returns"]
        else:
            # Roll out `num_episodes * 2` episodes: the first half is used to expose
            # the detector (weakest expert only), the second half is held out.
            logger.info(f"Rolling out expert {step}...")
            act_fn = functools.partial(expert.act, mode=deterministic_expert)
            episodes = [
                rollout_single_ep(env, act_fn, max_episode_len)
                for _ in range(num_episodes * 2)
            ]
            # Undiscounted return of every episode.
            ep_returns = np.array(
                [sum(tr["reward"] for tr in transitions) for transitions in episodes]
            )

            logger.info(f"Saving the episodes to {fpath}...")
            fpath.parent.mkdir(mode=0o775, parents=True, exist_ok=True)
            torch.save({"episodes": episodes, "ep_returns": ep_returns}, fpath)

        step2episodes[step] = episodes
        wandb.log(
            {
                "step": step,
                "ep_returns-mean": ep_returns.mean(),
                "ep_returns-stddev": ep_returns.std(),
                "ep_returns-median": np.median(ep_returns),
                "ep_returns-hist": wandb.Histogram(ep_returns.tolist()),
            }
        )

    # Sorted by step, so the first entry is the weakest (least-trained) expert.
    step_and_episodes = sorted(step2episodes.items())
    _, weakest_episodes = step_and_episodes[0]

    logger.info("Exposing states to the state detector...")

    # Expose the detector to the first half (training split) of the weakest expert's episodes.
    state_detector.experience(weakest_episodes[:num_episodes])

    # Evaluate on the held-out half of every expert's episodes (including the
    # weakest expert, whose held-out half serves as the in-distribution test set).
    logger.info("Evaluating the new states...")
    for step, episodes in step_and_episodes:
        # NOTE: slice the *list of episodes* (held-out split), not each episode.
        batch_states = extract_states(episodes[num_episodes:])
        logger.info(f"step {step} num states {batch_states.shape[0]}")
        scores = np.asarray(state_detector.batch_evaluate(batch_states))

        # Report mean, stddev, median (and histogram)
        wandb.log(
            {
                "expert-step": step,
                "eval/mean": scores.mean(),
                "eval/stddev": scores.std(),
                "eval/median": np.median(scores),
                "histogram": wandb.Histogram(scores.tolist()),
            }
        )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate new-state detectors on expert rollouts."
    )
    parser.add_argument("envname", help="Environment name (e.g., dmc:Cheetah-run-v1)")
    parser.add_argument(
        "--load-steps",
        nargs="+",
        default=[10, 800],
        type=int,
        help="Expert checkpoint steps; the smallest is treated as the weakest expert",
    )
    parser.add_argument("--save-video", action="store_true")
    # NOTE: --use-max-dist is parsed but not read below; --dev-est selects the estimator.
    parser.add_argument("--use-max-dist", action="store_true")
    parser.add_argument("--deter-exp", action="store_true", help='deterministic expert')
    parser.add_argument("--dev-est", choices=['max_dist', 'std_from_mean', 'std'], default='std')
    # The options below were previously hard-coded; the defaults keep the old behavior.
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--num-episodes",
        type=int,
        default=100,
        help="Episodes per split (2x this many are rolled out per expert)",
    )
    parser.add_argument(
        "--max-episode-len", type=int, default=1000, help="Maximum rollout length"
    )
    parser.add_argument(
        "--save-dir",
        default="/new-state-detection",
        help="Root directory of the cached expert episodes",
    )
    args = parser.parse_args()

    print("args", args)

    save_dir = Path(args.save_dir)

    # Default to the first GPU unless the caller pinned devices explicitly.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

    wandb.login()
    wandb.init(
        # Set the project where this run will be logged
        project="alops-new-state-detection",
        config=vars(Args),
    )
    # Record the CLI arguments as well. They define the experiment (detector
    # estimator, expert steps, determinism) but are not part of `Args`. They are
    # nested under one key to avoid clashing with entries from `Args`.
    wandb.config.update({"cli_args": vars(args)})

    # Ensemble network to predict the next state
    from .ensemble_network import StatePredEnsembleNewStateDetector
    _, state_dim, _, _ = simple_make_env(args.envname)
    state_detector = StatePredEnsembleNewStateDetector(state_dim, deviation_estimator=args.dev_est)

    #########################
    # Alternative detectors (swap in for the ensemble above):
    # state_detector = NewStateDetector()

    # # option in {mean_dist, min_dist, mean_min_n_dist, gap_accept_dist}

    # # option="min_dist"
    # state_detector = EuclideanDistance(min_n=0, threshold_accept=0, option="min_dist", debug=True)

    # # option="mean_dist"
    # state_detector = EuclideanDistance(min_n=0, threshold_accept=0, option="mean_dist", debug=True)

    # # option="mean_min_n_dist"
    # state_detector = EuclideanDistance(min_n=2, threshold_accept=0, option="mean_min_n_dist", debug=True)

    # # option="gap_accept_dist"
    # state_detector = EuclideanDistance(min_n=0, threshold_accept=3, option="gap_accept_dist", debug=True)

    ###########################
    # calculate the sigma for discrete space.
    # horizon = 10000
    # delta = 0.2
    # state_precision = 4
    # beta = 0.01
    # state_detector = DiscreteSigma(horizon, delta, state_precision, beta, True)

    main(
        args.envname,
        args.load_steps,
        save_dir,
        state_detector,
        seed=args.seed,
        num_episodes=args.num_episodes,
        max_episode_len=args.max_episode_len,
        save_video=args.save_video,
        deterministic_expert=args.deter_exp,
    )
